module GenerateConstraints (genConstraints) where

import Constraints
import Control.Arrow hiding (arr)
import Control.Monad.State
import Data.List as List
import Data.Maybe (catMaybes, fromMaybe, mapMaybe)
import Info
import Obj
import qualified Set
import TypeError
import Types
import Util

-- | Will create a list of type constraints for a form.
--
-- The 'Env' argument is currently unused. @rootSig@, when present, pairs the
-- type of an explicit signature for the root form (like @(sig foo (Fn ...))@)
-- with the form it came from. The constraints are returned sorted; the first
-- missing type or malformed form encountered is reported as a 'TypeError'.
genConstraints :: Env -> XObj -> Maybe (Ty, XObj) -> Either TypeError [Constraint]
genConstraints _ root rootSig = fmap sort (gen root)
  where
    -- Constraints for a function form (shared by 'defn' and 'fn'): the body
    -- type against the return type, each parameter against its argument type,
    -- the function's lifetime against the lifetime of every captured
    -- reference, and, for the root form only, the signature annotation.
    genF xobj args body captures =
      do
        insideBodyConstraints <- gen body
        xobjType <- toEither (xobjTy xobj) (DefnMissingType xobj)
        bodyType <- toEither (xobjTy body) (ExpressionMissingType xobj)
        case xobjType of
          (FuncTy argTys retTy lifetimeTy) ->
            let bodyConstr = Constraint retTy bodyType xobj body xobj OrdDefnBody
                argConstrs = zipWith3 (\a b aObj -> Constraint a b aObj xobj xobj OrdArg) (List.map forceTy args) argTys args
                -- The constraint generated by type signatures, like (sig foo (Fn ...)):
                -- This constraint is ignored for any xobj != rootxobj (ie. (fn) let bindings)
                sigConstr =
                  if root == xobj
                    then case rootSig of
                      Just (rootSigTy, rootSigXObj) -> [Constraint rootSigTy xobjType rootSigXObj xobj xobj OrdSignatureAnnotation]
                      Nothing -> []
                    else []
                -- Only captured variables whose type is a reference constrain
                -- the function's lifetime; all other captures are skipped.
                captureConstr :: XObj -> Maybe Constraint
                captureConstr captureObj =
                  case forceTy captureObj of
                    RefTy _ refLt -> Just (Constraint lifetimeTy refLt captureObj xobj xobj OrdCapture)
                    _ -> Nothing
                capturesConstrs = mapMaybe captureConstr (Set.toList captures)
             in pure (bodyConstr : argConstrs ++ insideBodyConstraints ++ capturesConstrs ++ sigConstr)
          _ -> Left (DefnMissingType xobj) -- TODO: Better error here.
    -- Dispatch on the shape of the form and generate its constraints,
    -- recursing into every sub-form.
    gen xobj =
      case xobjObj xobj of
        Lst lst -> case lst of
          -- Defn
          [XObj (Defn captures) _ _, _, XObj (Arr args) _ _, body] ->
            genF xobj args body (fromMaybe Set.empty captures)
          -- Fn
          [XObj (Fn _ captures) _ _, XObj (Arr args) _ _, body] ->
            genF xobj args body captures
          -- Def
          [XObj Def _ _, _, expr] ->
            do
              insideExprConstraints <- gen expr
              xobjType <- toEither (xobjTy xobj) (DefMissingType xobj)
              exprType <- toEither (xobjTy expr) (ExpressionMissingType xobj)
              let defConstraint = Constraint xobjType exprType xobj expr xobj OrdDefExpr
                  -- Unlike 'genF', the signature is applied here without
                  -- checking that this form is the root.
                  sigConstr = case rootSig of
                    Just (rootSigTy, rootSigXObj) -> [Constraint rootSigTy xobjType rootSigXObj xobj xobj OrdSignatureAnnotation]
                    Nothing -> []
              pure (defConstraint : insideExprConstraints ++ sigConstr)
          -- Let
          [XObj Let _ _, XObj (Arr bindings) _ _, body] ->
            do
              insideBodyConstraints <- gen body
              insideBindingsConstraints <- fmap join (mapM gen bindings)
              bodyType <- toEither (xobjTy body) (ExpressionMissingType body)
              case xobjTy xobj of
                Just xobjTy' ->
                  let wholeStatementConstraint = Constraint bodyType xobjTy' body xobj xobj OrdLetBody
                      -- Each bound symbol must have the type of its expression.
                      bindingsConstraints =
                        zipWith
                          ( \(symTy, exprTy) (symObj, exprObj) ->
                              Constraint symTy exprTy symObj exprObj xobj OrdLetBind
                          )
                          (List.map (forceTy *** forceTy) (pairwise bindings))
                          (pairwise bindings)
                   in pure
                        ( wholeStatementConstraint :
                          insideBodyConstraints
                            ++ bindingsConstraints
                            ++ insideBindingsConstraints
                        )
                Nothing -> Left (ExpressionMissingType xobj)
          -- If
          [XObj If _ _, expr, ifTrue, ifFalse] ->
            do
              insideConditionConstraints <- gen expr
              insideTrueConstraints <- gen ifTrue
              insideFalseConstraints <- gen ifFalse
              exprType <- toEither (xobjTy expr) (ExpressionMissingType expr)
              trueType <- toEither (xobjTy ifTrue) (ExpressionMissingType ifTrue)
              falseType <- toEither (xobjTy ifFalse) (ExpressionMissingType ifFalse)
              let expected = XObj (Sym (SymPath [] "Condition in if value") Symbol) (xobjInfo expr) (Just BoolTy)
                  conditionConstraint = Constraint exprType BoolTy expr expected xobj OrdIfCondition
                  sameReturnConstraint = Constraint trueType falseType ifTrue ifFalse xobj OrdIfReturn
              case xobjTy xobj of
                Just t ->
                  let wholeStatementConstraint = Constraint trueType t ifTrue xobj xobj OrdIfWhole
                   in pure
                        ( conditionConstraint :
                          sameReturnConstraint :
                          wholeStatementConstraint :
                          insideConditionConstraints
                            ++ insideTrueConstraints
                            ++ insideFalseConstraints
                        )
                Nothing -> Left (ExpressionMissingType xobj)
          -- Match
          XObj (Match matchMode) _ _ : expr : cases ->
            do
              insideExprConstraints <- gen expr
              -- Constraints for the variables in the left side of each matching case,
              -- like the 'r'/'g'/'b' in (match col (RGB r g b) ...) being constrained to Int.
              casesLhsConstraints <- fmap join (mapM (genConstraintsForCaseMatcher matchMode . fst) (pairwise cases))
              casesRhsConstraints <- fmap join (mapM (gen . snd) (pairwise cases))
              exprType <- toEither (xobjTy expr) (ExpressionMissingType expr)
              xobjType <- toEither (xobjTy xobj) (DefMissingType xobj)
              let -- Each case rhs should have the same return type as the whole match form:
                  mkRetConstr x@(XObj _ _ (Just t)) = Just (Constraint t xobjType x xobj xobj OrdArg) -- TODO: Ord
                  mkRetConstr _ = Nothing
                  returnConstraints = mapMaybe (\(_, rhs) -> mkRetConstr rhs) (pairwise cases)
                  -- Each case lhs should have the same type as the expression matching on
                  mkExprConstr x@(XObj _ _ (Just t)) = Just (Constraint (wrapTyInRefIfMatchingRef t) exprType x expr xobj OrdArg) -- TODO: Ord
                  mkExprConstr _ = Nothing
                  exprConstraints = mapMaybe (\(lhs, _) -> mkExprConstr lhs) (pairwise cases)
              pure
                ( insideExprConstraints
                    ++ casesLhsConstraints
                    ++ casesRhsConstraints
                    ++ returnConstraints
                    ++ exprConstraints
                )
            where
              -- In 'MatchRef' mode a case pattern's type is wrapped in a
              -- reference (with a placeholder lifetime variable) before it is
              -- compared with the type of the matched expression.
              wrapTyInRefIfMatchingRef t =
                case matchMode of
                  MatchValue -> t
                  MatchRef -> RefTy t (VarTy "whatever")
          -- While
          [XObj While _ _, expr, body] ->
            do
              insideConditionConstraints <- gen expr
              insideBodyConstraints <- gen body
              exprType <- toEither (xobjTy expr) (ExpressionMissingType expr)
              bodyType <- toEither (xobjTy body) (ExpressionMissingType body)
              let expectedCond = XObj (Sym (SymPath [] "Condition in while-expression") Symbol) (xobjInfo expr) (Just BoolTy)
                  expectedBody = XObj (Sym (SymPath [] "Body in while-expression") Symbol) (xobjInfo xobj) (Just UnitTy)
                  conditionConstraint = Constraint exprType BoolTy expr expectedCond xobj OrdWhileCondition
                  wholeStatementConstraint = Constraint bodyType UnitTy body expectedBody xobj OrdWhileBody
              pure
                ( conditionConstraint :
                  wholeStatementConstraint :
                  insideConditionConstraints ++ insideBodyConstraints
                )
          -- Do
          XObj Do _ _ : expressions ->
            -- Reversing gives total access to the last expression (the value
            -- of the form) and the preceding statements.
            case reverse expressions of
              [] -> Left (NoStatementsInDo xobj)
              lastExpr : reversedStatements ->
                do
                  insideExpressionsConstraints <- fmap join (mapM gen expressions)
                  xobjType <- toEither (xobjTy xobj) (DefMissingType xobj)
                  lastExprType <- toEither (xobjTy lastExpr) (ExpressionMissingType xobj)
                  let retConstraint = Constraint xobjType lastExprType xobj lastExpr xobj OrdDoReturn
                      must = XObj (Sym (SymPath [] "Statement in do-expression") Symbol) (xobjInfo xobj) (Just UnitTy)
                      -- Every statement except the last must be of Unit type.
                      mkConstr x@(XObj _ _ (Just t)) = Just (Constraint t UnitTy x must xobj OrdDoStatement)
                      mkConstr _ = Nothing
                      expressionsShouldReturnUnit = mapMaybe mkConstr (reverse reversedStatements)
                  pure (retConstraint : insideExpressionsConstraints ++ expressionsShouldReturnUnit)
          -- Set!
          [XObj SetBang _ _, variable, value] ->
            do
              insideValueConstraints <- gen value
              insideVariableConstraints <- gen variable
              variableType <- toEither (xobjTy variable) (ExpressionMissingType variable)
              valueType <- toEither (xobjTy value) (ExpressionMissingType value)
              let sameTypeConstraint = Constraint variableType valueType variable value xobj OrdSetBang
              pure (sameTypeConstraint : insideValueConstraints ++ insideVariableConstraints)
          -- The
          [XObj The _ _, _, value] ->
            do
              insideValueConstraints <- gen value
              xobjType <- toEither (xobjTy xobj) (DefMissingType xobj)
              valueType <- toEither (xobjTy value) (DefMissingType value)
              let theTheConstraint = Constraint xobjType valueType xobj value xobj OrdThe
              pure (theTheConstraint : insideValueConstraints)
          -- Ref
          [XObj Ref _ _, value] ->
            gen value
          -- Deref
          [XObj Deref _ _, value] ->
            do
              insideValueConstraints <- gen value
              xobjType <- toEither (xobjTy xobj) (ExpressionMissingType xobj)
              valueType <- toEither (xobjTy value) (ExpressionMissingType value)
              -- The dereferenced value must be a reference (with a fresh
              -- lifetime variable) to the type of the whole form.
              let lt = VarTy (makeTypeVariableNameFromInfo (xobjInfo xobj))
              let theTheConstraint = Constraint (RefTy xobjType lt) valueType xobj value xobj OrdDeref
              pure (theTheConstraint : insideValueConstraints)
          -- Break
          [XObj Break _ _] ->
            pure []
          -- Function application
          func : args ->
            do
              funcConstraints <- gen func
              variablesConstraints <- fmap join (mapM gen args)
              funcTy <- toEither (xobjTy func) (ExpressionMissingType func)
              case funcTy of
                (FuncTy argTys retTy _) ->
                  if length args /= length argTys
                    then Left (WrongArgCount func (length argTys) (length args))
                    else
                      let expected t n =
                            XObj
                              (Sym (SymPath [] ("Expected " ++ enumerate n ++ " argument to '" ++ getName func ++ "'")) Symbol)
                              (xobjInfo func)
                              (Just t)
                          argConstraints =
                            zipWith4
                              (\a t aObj n -> Constraint a t aObj (expected t n) xobj OrdFuncAppArg)
                              (List.map forceTy args)
                              argTys
                              args
                              [0 ..]
                       in case xobjTy xobj of
                            Just xobjTy' ->
                              let retConstraint = Constraint xobjTy' retTy xobj func xobj OrdFuncAppRet
                               in pure (retConstraint : funcConstraints ++ argConstraints ++ variablesConstraints)
                            Nothing -> Left (ExpressionMissingType xobj)
                -- The callee's type is still unknown: constrain it to a
                -- function from the argument types to the type of this form.
                funcVarTy@(VarTy _) ->
                  let fabricatedFunctionType = FuncTy (List.map forceTy args) (forceTy xobj) (VarTy "what?!")
                      expected = XObj (Sym (SymPath [] ("Calling '" ++ getName func ++ "'")) Symbol) (xobjInfo func) Nothing
                      wholeTypeConstraint = Constraint funcVarTy fabricatedFunctionType func expected xobj OrdFuncAppVarTy
                   in pure (wholeTypeConstraint : funcConstraints ++ variablesConstraints)
                _ -> Left (NotAFunction func)
          -- Empty list
          [] -> Right []
        Arr arr ->
          genArr "array" arrayElemTy xobj arr
        StaticArr arr ->
          genArr "static array" staticArrayElemTy xobj arr
        _ -> Right []
    -- Constraints shared by array and static-array literals: every element
    -- must have the type of the first element, and the first element's type
    -- must equal the element type that @elemTyOf@ extracts from the literal's
    -- own type. @what@ only affects the wording of the reported symbol.
    genArr :: String -> (Ty -> Maybe Ty) -> XObj -> [XObj] -> Either TypeError [Constraint]
    genArr _ _ _ [] = Right []
    genArr what elemTyOf xobj arr@(x : xs) =
      do
        insideExprConstraints <- fmap join (mapM gen arr)
        headTy <- toEither (xobjTy x) (ExpressionMissingType x)
        t <- toEither (xobjTy xobj >>= elemTyOf) (ExpressionMissingType xobj) -- TODO: Better error here.
        let genObj o n =
              XObj
                (Sym (SymPath [] ("Whereas the " ++ enumerate n ++ " element in the array is " ++ show (getPath o))) Symbol)
                (xobjInfo o)
                (xobjTy o)
            headObj =
              XObj
                (Sym (SymPath [] ("I inferred the type of the " ++ what ++ " from its first element " ++ show (getPath x))) Symbol)
                (xobjInfo x)
                (Just headTy)
            -- Element positions are 0-based (enumerate 0 is the first), so the
            -- tail starts at 1 and the head itself is element 0.
            betweenExprConstraints = zipWith (\o n -> Constraint headTy (forceTy o) headObj (genObj o n) xobj OrdArrBetween) xs [1 ..]
            headConstraint = Constraint headTy t headObj (genObj x 0) xobj OrdArrHead
        pure (headConstraint : insideExprConstraints ++ betweenExprConstraints)
    -- Element type of an @Array@ literal's type.
    arrayElemTy :: Ty -> Maybe Ty
    arrayElemTy (StructTy (ConcreteNameTy (SymPath [] "Array")) [t]) = Just t
    arrayElemTy _ = Nothing
    -- Element type of a @StaticArray@ literal's type, which is a reference.
    staticArrayElemTy :: Ty -> Maybe Ty
    staticArrayElemTy (RefTy (StructTy (ConcreteNameTy (SymPath [] "StaticArray")) [t]) _) = Just t
    staticArrayElemTy _ = Nothing

-- | Generate type constraints for the left-hand side (pattern) of one
-- @match@ case, such as @(RGB r g b)@.
--
-- A list pattern is treated like a call. Its head (the case name) must have
-- a function type whose arguments match the pattern's variables, and the
-- pattern's own type must match that function's return type. Nested
-- patterns are processed recursively. Argument types of plain symbol
-- variables are passed through 'wrapInRefTyIfMatchRef' with the given
-- 'MatchMode'; presumably this wraps them in a reference for ref matches,
-- which should be confirmed against its definition. Non-list patterns
-- (and empty lists) generate no constraints.
genConstraintsForCaseMatcher :: MatchMode -> XObj -> Either TypeError [Constraint]
genConstraintsForCaseMatcher matchMode = gen
  where
    gen xobj@(XObj (Lst (caseName : variables)) _ _) =
      do
        caseNameConstraints <- gen caseName
        variablesConstraints <- fmap join (mapM gen variables)
        caseNameTy <- toEither (xobjTy caseName) (ExpressionMissingType caseName)
        case caseNameTy of
          (FuncTy argTys retTy _) ->
            if length variables /= length argTys
              then Left (WrongArgCount caseName (length argTys) (length variables)) -- TODO: This could be another error since this isn't an actual function call.
              else
                let expected t n = XObj (Sym (SymPath [] ("Expected " ++ enumerate n ++ " argument to '" ++ getName caseName ++ "'")) Symbol) (xobjInfo caseName) (Just t)
                    argConstraints =
                      zipWith4
                        (\a t aObj n -> Constraint a t aObj (expected t n) xobj OrdFuncAppArg)
                        (List.map forceTy variables)
                        (zipWith refWrapper variables argTys)
                        variables
                        [0 ..]
                 in case xobjTy xobj of
                      Nothing -> Left (ExpressionMissingType xobj)
                      Just t ->
                        let retConstraint = Constraint t retTy xobj caseName xobj OrdFuncAppRet
                         in pure (retConstraint : caseNameConstraints ++ argConstraints ++ variablesConstraints)
          funcVarTy@(VarTy _) ->
            -- The case name's type is still unknown: constrain it to a function
            -- from the variables' types to the type of the whole pattern.
            let fabricatedFunctionType = FuncTy (List.map forceTy variables) (forceTy xobj) (VarTy "what?!") -- TODO: Fix
                expected = XObj (Sym (SymPath [] ("Matching on '" ++ getName caseName ++ "'")) Symbol) (xobjInfo caseName) Nothing
                wholeTypeConstraint = Constraint funcVarTy fabricatedFunctionType caseName expected xobj OrdFuncAppVarTy
             in pure (wholeTypeConstraint : caseNameConstraints ++ variablesConstraints)
          _ -> Left (NotAFunction caseName) -- TODO: This error could be more specific too, since it's not an actual function call.
    gen _ = pure []
    -- Only plain symbol variables get their expected type adjusted for the
    -- match mode; nested patterns keep the constructor's argument type as-is.
    refWrapper :: XObj -> Ty -> Ty
    refWrapper (XObj (Sym _ _) _ _) wrapThisType = wrapInRefTyIfMatchRef matchMode wrapThisType
    refWrapper _ t = t
